{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67736954-415d-413c-bf5f-e55ea4505ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import copy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1535745b-f3f0-40b8-84dd-fe600f411afd",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Dataset Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a28757f-3262-4d53-867e-b2f5afffbf11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset_webnlg(pair_src, pair_tgt, graph, text):\n",
    "    \"\"\"Append one leave-one-out example per triple of `graph`.\n",
    "\n",
    "    For every position i, the source file receives\n",
    "    `text + ' <S> ' + JSON(graph without triple i)` and the target file\n",
    "    receives triple i as JSON. Graphs with fewer than two triples produce\n",
    "    no examples, but both files are still opened (and created if missing).\n",
    "\n",
    "    Args:\n",
    "        pair_src: path of the source file, opened in append mode.\n",
    "        pair_tgt: path of the target file, opened in append mode.\n",
    "        graph: list of JSON-serialisable triples; it is not modified.\n",
    "        text: string paired with the graph.\n",
    "    \"\"\"\n",
    "    # `with` closes both handles even when serialisation raises midway.\n",
    "    with open(pair_src, 'a') as wf_src, open(pair_tgt, 'a') as wf_tgt:\n",
    "        # The size check does not depend on the loop index, so it is done once.\n",
    "        if len(graph) > 1:\n",
    "            for i, removed in enumerate(graph):\n",
    "                # Slicing builds the reduced graph without deep-copying it.\n",
    "                remaining = graph[:i] + graph[i + 1:]\n",
    "                wf_src.write(text + ' <S> ' + json.dumps(remaining, ensure_ascii=False) + '\\n')\n",
    "                wf_tgt.write(json.dumps(removed, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a22fade5-7173-409d-ad48-93ea4208f1c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset_kelm(pair_src, pair_tgt, graph, text):\n",
    "    \"\"\"Append a 'Correct' example and, for larger graphs, one with a triple removed.\n",
    "\n",
    "    The first pair written is always `text + ' <S> ' + JSON(graph)` with the\n",
    "    target `Correct`. When the graph holds at least two triples, one randomly\n",
    "    selected triple is removed; the reduced graph goes to the source file and\n",
    "    the removed triple (as JSON) to the target file.\n",
    "\n",
    "    Args:\n",
    "        pair_src: path of the source file, opened in append mode.\n",
    "        pair_tgt: path of the target file, opened in append mode.\n",
    "        graph: list of JSON-serialisable triples; it is not modified.\n",
    "        text: string paired with the graph.\n",
    "    \"\"\"\n",
    "    # `with` closes both handles even when serialisation raises midway.\n",
    "    with open(pair_src, 'a') as wf_src, open(pair_tgt, 'a') as wf_tgt:\n",
    "        wf_src.write(text + ' <S> ' + json.dumps(graph, ensure_ascii=False) + '\\n')\n",
    "        wf_tgt.write('Correct' + '\\n')\n",
    "\n",
    "        # `> 1` rather than `!= 1`: an empty graph has nothing to pick from and\n",
    "        # would otherwise make the random selection raise.\n",
    "        if len(graph) > 1:\n",
    "            # Pick a position, not a value, so that the triple actually chosen\n",
    "            # is the one removed even when the graph contains equal triples.\n",
    "            drop_at = random.randrange(len(graph))\n",
    "            remaining = graph[:drop_at] + graph[drop_at + 1:]\n",
    "            wf_src.write(text + ' <S> ' + json.dumps(remaining, ensure_ascii=False) + '\\n')\n",
    "            wf_tgt.write(json.dumps(graph[drop_at], ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e36640a-e3e8-4eb5-bd19-30be45a9103a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset_unseen(pair_src, pair_tgt, graph, text):\n",
    "    \"\"\"Append leave-one-out examples for triples whose fields are absent from `text`.\n",
    "\n",
    "    A triple at position i is used only when neither its second nor its third\n",
    "    element (each with its first two characters dropped, then stripped and\n",
    "    lower-cased) occurs as a substring of the lower-cased text. For such a\n",
    "    triple the source file receives `text + ' <S> ' + JSON(graph without\n",
    "    triple i)` and the target file receives the triple without its last\n",
    "    element, as JSON. Graphs with fewer than two triples produce no examples.\n",
    "\n",
    "    Args:\n",
    "        pair_src: path of the source file, opened in append mode.\n",
    "        pair_tgt: path of the target file, opened in append mode.\n",
    "        graph: list of triples whose second and third elements are strings;\n",
    "            it is not modified.\n",
    "        text: string paired with the graph.\n",
    "    \"\"\"\n",
    "    # `with` closes both handles even when serialisation raises midway.\n",
    "    with open(pair_src, 'a') as wf_src, open(pair_tgt, 'a') as wf_tgt:\n",
    "        # The size check does not depend on the loop index, so it is done once.\n",
    "        if len(graph) > 1:\n",
    "            # Lower-case the text once instead of once per comparison.\n",
    "            text_lower = text.lower()\n",
    "            for i, triple in enumerate(graph):\n",
    "                # [2:] skips the first two characters of the field\n",
    "                # (presumably a marker prefix - TODO confirm against the data).\n",
    "                if triple[1][2:].strip().lower() in text_lower:\n",
    "                    continue\n",
    "                # Checked separately so the third element is only read when\n",
    "                # the second one did not already rule the triple out.\n",
    "                if triple[2][2:].strip().lower() in text_lower:\n",
    "                    continue\n",
    "                remaining = graph[:i] + graph[i + 1:]\n",
    "                wf_src.write(text + ' <S> ' + json.dumps(remaining, ensure_ascii=False) + '\\n')\n",
    "                # The target keeps everything but the triple's last element.\n",
    "                wf_tgt.write(json.dumps(triple[:-1], ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c3a207c-ae7e-4757-9278-2a12539e5204",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset_correct(pair_src, pair_tgt, graph, text):\n",
    "    \"\"\"Append one example that labels the unmodified graph as 'Correct'.\n",
    "\n",
    "    The source file receives `text + ' <S> ' + JSON(graph)` and the target\n",
    "    file receives the line `Correct`.\n",
    "\n",
    "    Args:\n",
    "        pair_src: path of the source file, opened in append mode.\n",
    "        pair_tgt: path of the target file, opened in append mode.\n",
    "        graph: JSON-serialisable graph.\n",
    "        text: string paired with the graph.\n",
    "    \"\"\"\n",
    "    # `with` closes both handles even when serialisation raises.\n",
    "    with open(pair_src, 'a') as wf_src, open(pair_tgt, 'a') as wf_tgt:\n",
    "        wf_src.write(text + ' <S> ' + json.dumps(graph, ensure_ascii=False) + '\\n')\n",
    "        wf_tgt.write('Correct' + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011b8e21-780a-48c9-9951-8621a9a8dc44",
   "metadata": {},
   "outputs": [],
   "source": [
    "pair_train_src = \"data/only_one_error_webnlg/train.source\"\n",
    "pair_train_tgt = \"data/only_one_error_webnlg/train.target\"\n",
    "\n",
    "# The generators open these files in append mode, which fails when the\n",
    "# folder is missing, so create it first.\n",
    "# NOTE: because of append mode, re-running this cell adds the examples again.\n",
    "os.makedirs(os.path.dirname(pair_train_src), exist_ok=True)\n",
    "\n",
    "# Read both inputs up front so the handles are released before processing.\n",
    "with open(\"webnlg20/train.source\", 'r') as fa, open('webnlg20/train.target', 'r') as fb:\n",
    "    a = fa.readlines()\n",
    "    b = fb.readlines()\n",
    "\n",
    "# Graphs that already received leave-one-out examples, in first-seen order.\n",
    "existing_graphs = []\n",
    "# repr() of every graph in `existing_graphs`: a hashable stand-in that turns\n",
    "# the duplicate check into a set lookup instead of a scan over all earlier\n",
    "# graphs. Assumes graphs are nested lists of plain literals, for which equal\n",
    "# graphs have equal repr() - TODO confirm no dicts/sets occur in the input.\n",
    "seen_graph_keys = set()\n",
    "for i in range(len(a)):\n",
    "    # change string to list\n",
    "    graph = ast.literal_eval(a[i].strip())\n",
    "    text = b[i].strip()\n",
    "    # Every (text, graph) line yields a 'Correct' example ...\n",
    "    generate_dataset_correct(pair_train_src, pair_train_tgt, graph, text)\n",
    "    # ... but the corrupted examples are generated once per distinct graph.\n",
    "    graph_key = repr(graph)\n",
    "    if graph_key not in seen_graph_keys:\n",
    "        generate_dataset_webnlg(pair_train_src, pair_train_tgt, graph, text)\n",
    "        # generate_dataset_unseen(pair_train_src, pair_train_tgt, graph, text)\n",
    "        seen_graph_keys.add(graph_key)\n",
    "        existing_graphs.append(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d61d9a-2b29-40c4-8d31-3aaca4ae1add",
   "metadata": {},
   "outputs": [],
   "source": [
    "pair_train_src = \"data/only_one_error_kelm/train.source\"\n",
    "pair_train_tgt = \"data/only_one_error_kelm/train.target\"\n",
    "\n",
    "# The generator opens these files in append mode, which fails when the\n",
    "# folder is missing, so create it first.\n",
    "# NOTE: because of append mode, re-running this cell adds the examples again.\n",
    "os.makedirs(os.path.dirname(pair_train_src), exist_ok=True)\n",
    "\n",
    "# generate_dataset_kelm removes a randomly selected triple; seeding the\n",
    "# module-level generator makes the produced dataset repeatable across runs.\n",
    "random.seed(42)\n",
    "\n",
    "# Read both inputs up front so the handles are released before processing.\n",
    "with open(\"data/kelm_subset/train.source\", 'r') as fa, open('data/kelm_subset/train.target', 'r') as fb:\n",
    "    a = fa.readlines()\n",
    "    b = fb.readlines()\n",
    "\n",
    "for i in range(len(a)):\n",
    "    # change string to list\n",
    "    graph = ast.literal_eval(a[i].strip())\n",
    "    text = b[i].strip()\n",
    "    generate_dataset_kelm(pair_train_src, pair_train_tgt, graph, text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a005214-469d-445a-bda3-97a548609c5a",
   "metadata": {},
   "source": [
    "# Post-Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0791fe-4f25-4f1f-be6f-466bb9a60ce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create unified verifier input\n",
    "# NOTE(review): graphs are read from the 6-shot Iteration1 folder, texts from\n",
    "# the 1-shot folder, and the result is written to Iteration3 - confirm that\n",
    "# this combination of paths is intended.\n",
    "all_data = []\n",
    "# Read as UTF-8 to match the encoding used for the output file below;\n",
    "# otherwise the platform default decides how non-ASCII text is decoded.\n",
    "with open(\"GPT3.5_result_KELM_6_shots/Iteration1/test_generated_graphs.txt\", 'r', encoding='utf-8') as fa, open('GPT3.5_result_KELM_1_shots/test.target', 'r', encoding='utf-8') as fb:\n",
    "    a = fa.readlines()\n",
    "    b = fb.readlines()\n",
    "\n",
    "for i in range(len(a)):\n",
    "    graph = a[i].strip()\n",
    "    text = b[i].strip()\n",
    "    all_data.append({\n",
    "        \"instruction\": \"Predict the missing triple given the text and graph for KELM dataset.\",\n",
    "        # Double quotes in both parts are replaced by single quotes.\n",
    "        \"input\": text.replace('\"', \"'\") + ' <S> ' + graph.replace('\"', \"'\"),\n",
    "    })\n",
    "\n",
    "with open('GPT3.5_result_KELM_6_shots/Iteration3/verifier_input.json','w',encoding='utf-8') as file:\n",
    "    file.write(json.dumps(all_data, indent=4, ensure_ascii=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8bfa632-de7c-4ab2-aa8c-efe73669ec74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create single verifier input\n",
    "# Each output line is '<text> <S> <generated graph>', pairing the two input\n",
    "# files line by line.\n",
    "with open(\"GPT3.5_result_GenWiki/Iteration1/test_generated_graphs.txt\", 'r') as f:\n",
    "    test_generated_graphs = [line.strip() for line in f]\n",
    "\n",
    "with open(\"GPT3.5_result_GenWiki/test.target\", 'r') as f:\n",
    "    test_texts = [line.strip() for line in f]\n",
    "\n",
    "with open(\"GPT3.5_result_GenWiki/Iteration1/test.source\", 'w') as f:\n",
    "    # Indexing the texts (rather than zipping) keeps the failure on a\n",
    "    # too-short text file loud.\n",
    "    for i, generated_graph in enumerate(test_generated_graphs):\n",
    "        f.write(test_texts[i] + ' <S> ' + generated_graph + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e46091ba-a744-4473-95ec-fb0f82690579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterative Offline Correction\n",
    "# Lines judged 'Correct' by the verifier are copied unchanged; for every other\n",
    "# line the verifier output is appended to the generated graph as an extra item.\n",
    "with open(\"GPT3.5_result_KELM_6_shots/Iteration3/test_generated_graphs.txt\", 'r') as f:\n",
    "    full_test_generated_graphs = [line.strip() for line in f]\n",
    "\n",
    "with open(\"GPT3.5_result_KELM_6_shots/Iteration3/verifier_result.txt\", 'r') as f:\n",
    "    full_verifier_texts = [line.strip() for line in f]\n",
    "\n",
    "with open(\"GPT3.5_result_KELM_6_shots/Iteration4/test_generated_graphs.txt\", 'w') as f:\n",
    "    for i, generated_graph in enumerate(full_test_generated_graphs):\n",
    "        verifier_text = full_verifier_texts[i]\n",
    "        if verifier_text == 'Correct':\n",
    "            f.write(generated_graph + '\\n')\n",
    "            continue\n",
    "        # Drop the last character (assumed to be the closing ']' of the\n",
    "        # graph - TODO confirm for malformed generations).\n",
    "        graph_head = generated_graph[:-1]\n",
    "        # An empty graph '[]' has no item to separate the new one from;\n",
    "        # joining with ', ' there would produce the invalid '[, ...]'.\n",
    "        separator = '' if graph_head.strip() == '[' else ', '\n",
    "        f.write(graph_head + separator + verifier_text + ']' + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e78b0be-ba00-45c1-b6f2-5abce556e580",
   "metadata": {},
   "source": [
    "# Calculate Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6cbc9ec-0ec4-448c-bee0-c6194100ed40",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_accuracy(verification_results):\n",
    "    \"\"\"Print and return the fraction of results equal to 'Correct'.\n",
    "\n",
    "    Args:\n",
    "        verification_results: sized iterable of verifier outputs; an item\n",
    "            counts as correct only when it equals the string 'Correct'.\n",
    "\n",
    "    Returns:\n",
    "        Number of 'Correct' items divided by the total number of items,\n",
    "        or 0.0 when there are no items.\n",
    "    \"\"\"\n",
    "    total = len(verification_results)\n",
    "    if total == 0:\n",
    "        # Nothing to score: report 0.0 instead of dividing by zero.\n",
    "        accuracy = 0.0\n",
    "    else:\n",
    "        correct_num = sum(1 for result in verification_results if result == 'Correct')\n",
    "        accuracy = correct_num / total\n",
    "\n",
    "    print('Accuracy: ', accuracy)\n",
    "    return accuracy"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
